# Getting Started

Create a virtual environment with `python3 -m venv venv`, activate it with `source venv/bin/activate`, and then run `pip install -r requirements.txt` to install the required packages.
The underlying dataset is ImageNet-1K (ILSVRC 2012), which can be downloaded from https://www.image-net.org/.


# Training Scripts

To train HydraViT, run the following command:
```
./distributed_train.sh <number of GPUs> <path/to/imagenet> --config config.yaml --model deit_base_patch16_224_half --workers 8 --epochs 300 --drop-p 1
```
This trains HydraViT with 3 to 12 attention heads using uniform subnetwork sampling. Note that this requires two GPUs equivalent to an A100.
Adding the flag `--tsb` enables training with 3 subnetworks, specifically 3, 6 and 12 heads.
Adding the flag `--weighted` enables weighted subnetwork sampling; the sampling weights currently have to be adjusted in `train.py`.
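
To illustrate the difference between uniform and weighted subnetwork sampling, here is a minimal sketch that draws one head count per training step. It is not the code in `train.py`: the function name `sample_num_heads`, the constant `HEAD_CHOICES`, and the example weights are all hypothetical and only assume that subnetworks are identified by their number of heads.

```python
import random

# Hypothetical illustration; names and weights are assumptions, not the repository's code.
HEAD_CHOICES = list(range(3, 13))  # subnetworks with 3 to 12 attention heads


def sample_num_heads(rng, weights=None, choices=HEAD_CHOICES):
    """Draw the head count of the subnetwork to train in the next step.

    With weights=None every head count is equally likely (uniform sampling);
    otherwise head counts are drawn in proportion to the given weights.
    """
    if weights is not None and len(weights) != len(choices):
        raise ValueError("need exactly one weight per head count")
    return rng.choices(choices, weights=weights, k=1)[0]


rng = random.Random(0)

# Uniform sampling over all ten head counts.
uniform_draw = sample_num_heads(rng)

# Example weighting (made up) that favours the 3-, 6- and 12-head subnetworks.
example_weights = [4 if h in (3, 6, 12) else 1 for h in HEAD_CHOICES]
weighted_draw = sample_num_heads(rng, weights=example_weights)

# Restricting the choices to three subnetworks, similar in spirit to `--tsb`.
three_subnet_draw = sample_num_heads(rng, choices=[3, 6, 12])
```

Passing a seeded `random.Random` keeps the draws reproducible, which can help when comparing sampling schemes.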

# Validation Scripts

To validate accuracy, run the following command:
```
python3 validate.py <path/to/imagenet/val> --model deit_dyn_patch16_224 --checkpoint <path/to/checkpoint> -b <batch size> --use-ema --drop-p <embedding dimension>
```
This will evaluate the checkpoint at the chosen embedding dimension, where 192 corresponds to 3 heads or the architecture of DeiT-tiny, 384 corresponds to 6 heads or DeiT-small, and 768 corresponds to 12 heads or DeiT-base.
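
The three pairs above all share the same ratio, 192/3 = 384/6 = 768/12 = 64 dimensions per head. The sketch below converts between head count and the value passed to `--drop-p`. It assumes this ratio also holds for the intermediate head counts (4, 5, 7, ...), which the text above does not state explicitly. The helper names are hypothetical.

```python
# Ratio derived from the documented pairs: 192/3 = 384/6 = 768/12 = 64.
DIM_PER_HEAD = 64
MIN_HEADS, MAX_HEADS = 3, 12


def heads_to_embed_dim(num_heads):
    """Return the embedding dimension for a given number of attention heads."""
    if not MIN_HEADS <= num_heads <= MAX_HEADS:
        raise ValueError(f"expected {MIN_HEADS} to {MAX_HEADS} heads, got {num_heads}")
    return num_heads * DIM_PER_HEAD


def embed_dim_to_heads(embed_dim):
    """Return the number of heads for an embedding dimension, or raise if invalid."""
    heads, remainder = divmod(embed_dim, DIM_PER_HEAD)
    if remainder != 0 or not MIN_HEADS <= heads <= MAX_HEADS:
        raise ValueError(f"{embed_dim} is not a multiple of {DIM_PER_HEAD} in the supported range")
    return heads


# The three documented configurations.
documented = {name: heads_to_embed_dim(h) for name, h in
              [("DeiT-tiny", 3), ("DeiT-small", 6), ("DeiT-base", 12)]}
```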

To measure throughput, run:
```
python3 validate_throughput.py <path/to/imagenet/val> --model deit_dyn_patch16_224 -b 512 --drop-p <embedding dimension>
```

To measure MACs, run the following (this requires `pip install deepspeed`):
```
python3 validate_macs.py <path/to/imagenet/val> --model deit_dyn_patch16_224 -b 1 --drop-p <embedding dimension>
```

To measure RAM usage and the number of model parameters, run the following (this requires `pip install torchinfo`):
```
python3 validate_memory.py <path/to/imagenet/val> --model deit_dyn_patch16_224 --drop-p <embedding dimension>
```
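
To compare subnetworks, the same script can be run once per embedding dimension. The sketch below only assembles the command lines and prints them. The helper `build_command` is hypothetical, the dataset path is a placeholder, and only flags shown in the commands above are used.

```python
import shlex

MODEL = "deit_dyn_patch16_224"


def build_command(script, val_dir, embed_dim, extra_args=()):
    """Assemble the argument list for one validation run at a given embedding dimension."""
    return ["python3", script, val_dir, "--model", MODEL,
            *extra_args, "--drop-p", str(embed_dim)]


# Placeholder path; replace it with the actual ImageNet validation directory.
VAL_DIR = "path/to/imagenet/val"

commands = [
    build_command("validate_throughput.py", VAL_DIR, dim, extra_args=["-b", "512"])
    for dim in (192, 384, 768)
]

for cmd in commands:
    # shlex.join quotes arguments so the line can be pasted into a shell.
    print(shlex.join(cmd))
```

Each list can also be passed to `subprocess.run(cmd, check=True)` to execute the runs one after another.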